{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "089e364e",
      "metadata": {},
      "source": [
        "# Training with Opacus on multiple GPUs with Distributed Data Parallel"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "afda23db",
      "metadata": {},
      "source": [
        "In this tutorial we'll go over the basics you need to know to start using Opacus in your distributed model training pipeline. As state-of-the-art models and datasets get bigger, multi-GPU training has become the norm, and Opacus comes with seamless, out-of-the-box support for Distributed Data Parallel (DDP).\n",
        "\n",
        "This tutorial requires basic knowledge of Opacus and DDP. If you're new to either of these tools, we suggest starting with the following tutorials: [Building an Image Classifier with Differential Privacy](https://opacus.ai/tutorials/building_image_classifier) and [Getting Started with Distributed Data Parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).\n",
        "\n",
        "In Chapter 1 we'll start with a minimal working example to demonstrate exactly what you need to do to make Opacus work in a distributed setting. This should be enough to get started in most common scenarios.\n",
        "\n",
        "In Chapters 2 and 3 we'll take a closer look at the implementation and discuss the technical details. We'll see how private DDP differs from regular DDP and why these differences are needed."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e04831f6",
      "metadata": {},
      "source": [
        "## Chapter 0: Preparations"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1089c8c1",
      "metadata": {},
      "source": [
        "Before we begin, there are a few things we need to mention.\n",
        "\n",
        "First, this tutorial is written to be executed on a single Linux machine with at least 2 GPUs. The general principles remain the same for Windows environments and/or multi-node training, but you'll need to slightly modify the DDP code to make it work.\n",
        "\n",
        "Second, Jupyter notebooks are [known](https://discuss.pytorch.org/t/distributeddataparallel-on-terminal-vs-jupyter-notebook/101404) not to support DDP training. Throughout the tutorial, we'll use the `%%writefile` magic command to write code to a separate file and later execute it via the terminal. These files will be cleaned up in the last cell of this notebook."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a702f3d6",
      "metadata": {},
      "source": [
        "## Chapter 1: Getting Started"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3a9a2af5",
      "metadata": {},
      "source": [
        "First, let's initialise the distributed environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "7189b698",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Overwriting opacus_ddp_demo.py\n"
          ]
        }
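      ],
      "source": [
        "%%writefile opacus_ddp_demo.py\n",
        "import os\n",
        "import torch.distributed as dist\n",
        "import logging\n",
        "\n",
        "# Configure logging explicitly so INFO messages are printed;\n",
        "# the format matches the log output shown later in this notebook\n",
        "logging.basicConfig(\n",
        "    format=\"%(asctime)s:%(levelname)s:%(message)s\",\n",
        "    datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
        ")\n",
        "logger = logging.getLogger(__name__)\n",
        "logger.setLevel(logging.INFO)\n",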
        "\n",
        "def setup(rank, world_size):\n",
        "    os.environ['MASTER_ADDR'] = 'localhost'\n",
        "    os.environ['MASTER_PORT'] = '12355'\n",
        "\n",
        "    # initialize the process group\n",
        "    dist.init_process_group(\"gloo\", rank=rank, world_size=world_size)\n",
        "\n",
        "def cleanup():\n",
        "    dist.destroy_process_group()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "36d6ebef",
      "metadata": {},
      "source": [
        "We'll be using MNIST for a toy example, so let's also initialize a simple convolutional network and download the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "336d6560",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_ddp_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_ddp_demo.py\n",
        "\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class SampleConvNet(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 16, 8, 2, padding=3)\n",
        "        self.conv2 = nn.Conv2d(16, 32, 4, 2)\n",
        "        self.fc1 = nn.Linear(32 * 4 * 4, 32)\n",
        "        self.fc2 = nn.Linear(32, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x of shape [B, 1, 28, 28]\n",
        "        x = F.relu(self.conv1(x))  # -> [B, 16, 14, 14]\n",
        "        x = F.max_pool2d(x, 2, 1)  # -> [B, 16, 13, 13]\n",
        "        x = F.relu(self.conv2(x))  # -> [B, 32, 5, 5]\n",
        "        x = F.max_pool2d(x, 2, 1)  # -> [B, 32, 4, 4]\n",
        "        x = x.view(-1, 32 * 4 * 4)  # -> [B, 512]\n",
        "        x = F.relu(self.fc1(x))  # -> [B, 32]\n",
        "        x = self.fc2(x)  # -> [B, 10]\n",
        "        return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "254db444",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_ddp_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_ddp_demo.py\n",
        "\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "# Precomputed characteristics of the MNIST dataset\n",
        "MNIST_MEAN = 0.1307\n",
        "MNIST_STD = 0.3081\n",
        "\n",
        "DATA_ROOT = \"./mnist\"\n",
        "\n",
        "mnist_train_ds = datasets.MNIST(\n",
        "    DATA_ROOT,\n",
        "    train=True,\n",
        "    download=True,\n",
        "    transform=transforms.Compose(\n",
        "        [\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),\n",
        "        ]\n",
        "    ),\n",
        ")\n",
        "\n",
        "mnist_test_ds = datasets.MNIST(\n",
        "    DATA_ROOT,\n",
        "    train=False,\n",
        "    download=True,\n",
        "    transform=transforms.Compose(\n",
        "        [\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),\n",
        "        ]\n",
        "    ),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3b754537",
      "metadata": {},
      "source": [
        "Coming next is the key bit - and the only one that's different from non-private DDP.\n",
        "\n",
        "First, instead of wrapping the model with `DistributedDataParallel`, we'll wrap it with `DifferentiallyPrivateDistributedDataParallel` from the `opacus.distributed` package. Simple as that.\n",
        "\n",
        "The second difference comes when initializing the `DataLoader`. Normally, for distributed training you would initialize a data loader specific to your distributed setup. This affects two parameters:\n",
        "- Batch size denotes the per-GPU batch size. That is, your logical batch size (the one that matters for convergence) is equal to `local_batch_size*num_gpus`.\n",
        "- You need to specify `sampler=DistributedSampler(dataset)` to distribute the training dataset across GPUs.\n",
        "\n",
        "With Opacus you don't need to do either of those things. The `make_private` method expects the user-provided `DataLoader` to be non-distributed, initialized as if you were training on a single GPU.\n",
        "\n",
        "The code below highlights changes you need to make to a normal DDP training pipeline by commenting out lines you need to replace or remove."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "4af4f1d0",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_ddp_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_ddp_demo.py\n",
        "\n",
        "import torch\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import DataLoader, DistributedSampler\n",
        "\n",
        "from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP\n",
        "from torch.nn.parallel import DistributedDataParallel as DDP\n",
        "\n",
        "from opacus import PrivacyEngine\n",
        "\n",
        "LR = 0.1\n",
        "BATCH_SIZE = 200\n",
        "N_GPUS = torch.cuda.device_count()\n",
        "\n",
        "def init_training(rank):\n",
        "    model = SampleConvNet()\n",
        "    #model = DDP(model) -- non-private\n",
        "    model = DPDDP(model)\n",
        "\n",
        "    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0)\n",
        "    data_loader = DataLoader(\n",
        "        mnist_train_ds,\n",
        "        #batch_size=BATCH_SIZE // N_GPUS, -- non-private\n",
        "        batch_size=BATCH_SIZE,\n",
        "        #sampler=DistributedSampler(mnist_train_ds) -- non-private\n",
        "    )\n",
        "\n",
        "    if rank == 0:\n",
        "        logger.info(\n",
        "            f\"(rank {rank}) Initialized model ({type(model).__name__}), \"\n",
        "            f\"optimizer ({type(optimizer).__name__}), \"\n",
        "            f\"data loader ({type(data_loader).__name__}, len={len(data_loader)})\"\n",
        "        )\n",
        "\n",
        "    privacy_engine = PrivacyEngine()\n",
        "\n",
        "    # PrivacyEngine looks at the model's class and enables\n",
        "    # distributed processing if it's wrapped with DPDDP\n",
        "    model, optimizer, data_loader = privacy_engine.make_private(\n",
        "        module=model,\n",
        "        optimizer=optimizer,\n",
        "        data_loader=data_loader,\n",
        "        noise_multiplier=1.,\n",
        "        max_grad_norm=1.,\n",
        "    )\n",
        "\n",
        "    if rank == 0:\n",
        "        logger.info(\n",
        "            f\"(rank {rank}) After privatization: model ({type(model).__name__}), \"\n",
        "            f\"optimizer ({type(optimizer).__name__}), \"\n",
        "            f\"data loader ({type(data_loader).__name__}, len={len(data_loader)})\"\n",
        "        )\n",
        "\n",
        "    logger.info(f\"(rank {rank}) Average batch size per GPU: {int(optimizer.expected_batch_size)}\")\n",
        "\n",
        "    return model, optimizer, data_loader, privacy_engine"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "baae0bb5",
      "metadata": {},
      "source": [
        "Now we just need to define the training loop and launch it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "26c24fc7",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_ddp_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_ddp_demo.py\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "def test(model, device):\n",
        "    test_loader = DataLoader(\n",
        "        mnist_test_ds,\n",
        "        batch_size=BATCH_SIZE,\n",
        "    )\n",
        "\n",
        "    model.eval()\n",
        "    correct = 0\n",
        "    with torch.no_grad():\n",
        "        for data, target in test_loader:\n",
        "            data, target = data.to(device), target.to(device)\n",
        "            output = model(data)\n",
        "            pred = output.argmax(\n",
        "                dim=1, keepdim=True\n",
        "            )\n",
        "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
        "\n",
        "    model.train()\n",
        "    return correct / len(mnist_test_ds)\n",
        "\n",
        "def launch(rank, world_size, epochs):\n",
        "    setup(rank, world_size)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    model, optimizer, data_loader, privacy_engine = init_training(rank)\n",
        "    model.to(rank)\n",
        "    model.train()\n",
        "\n",
        "\n",
        "    for e in range(epochs):\n",
        "        losses = []\n",
        "        correct = 0\n",
        "        total = 0\n",
        "\n",
        "        for data, target in data_loader:\n",
        "            data, target = data.to(rank), target.to(rank)\n",
        "            optimizer.zero_grad()\n",
        "            output = model(data)\n",
        "\n",
        "            pred = output.argmax(dim=1, keepdim=True)\n",
        "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
        "            total += len(data)\n",
        "\n",
        "            loss = criterion(output, target)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            losses.append(loss.item())\n",
        "\n",
        "        test_accuracy = test(model, rank)\n",
        "        train_accuracy = correct / total\n",
        "        epsilon = privacy_engine.get_epsilon(delta=1e-5)\n",
        "\n",
        "        if rank == 0:\n",
        "            print(\n",
        "                f\"Epoch: {e} \\t\"\n",
        "                f\"Train Loss: {np.mean(losses):.4f} | \"\n",
        "                f\"Train Accuracy: {train_accuracy:.2f} | \"\n",
        "                f\"Test Accuracy: {test_accuracy:.2f} |\"\n",
        "                f\"(ε = {epsilon:.2f})\"\n",
        "            )\n",
        "\n",
        "    cleanup()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "03256ef1",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_ddp_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_ddp_demo.py\n",
        "\n",
        "import torch.multiprocessing as mp\n",
        "\n",
        "EPOCHS = 10\n",
        "world_size = torch.cuda.device_count()\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    mp.spawn(\n",
        "        launch,\n",
        "        args=(world_size,EPOCHS,),\n",
        "        nprocs=world_size,\n",
        "        join=True\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e8703467",
      "metadata": {},
      "source": [
        "And, finally, let's run the script. Notice that we've initialized our `DataLoader` with `batch_size=200`, which is equivalent to 300 batches on the full dataset (60000 images).\n",
        "\n",
        "After passing it to `make_private`, each worker has a data loader with an average batch size of 100 (half of the logical batch size, since this run used 2 GPUs), but each data loader still goes over 300 batches."
      ]
    },
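    {
      "cell_type": "markdown",
      "id": "c1d4a7e0",
      "metadata": {},
      "source": [
        "The arithmetic behind these numbers can be sketched in a few lines. This is only an illustration, assuming 2 GPUs as in the run shown in this chapter; `describe_loader` is a hypothetical helper and not part of Opacus.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def describe_loader(dataset_size, logical_batch_size, world_size):\n",
        "    # Hypothetical helper: reproduces the tutorial's arithmetic only.\n",
        "    # The number of steps per epoch is fixed by the logical batch size.\n",
        "    num_batches = math.ceil(dataset_size / logical_batch_size)\n",
        "    # On average, each worker handles an equal share of every logical batch.\n",
        "    per_gpu_batch_size = logical_batch_size // world_size\n",
        "    return num_batches, per_gpu_batch_size\n",
        "\n",
        "print(describe_loader(60000, 200, 2))  # (300, 100)\n",
        "```\n",
        "\n",
        "The number of batches does not depend on the number of GPUs: adding workers shrinks the per-GPU share of each batch, not the number of optimization steps."
      ]
    },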
    {
      "cell_type": "code",
      "execution_count": 7,
      "id": "48c97af5",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "05/13/2022 11:13:16:INFO:(rank 0) Initialized model (DifferentiallyPrivateDistributedDataParallel), optimizer (SGD), data loader (DataLoader, len=300)\n",
            "05/13/2022 11:13:16:INFO:(rank 1) Average batch size per GPU: 100\n",
            "05/13/2022 11:13:16:INFO:(rank 0) After privatization: model (GradSampleModule), optimizer (DistributedDPOptimizer), data loader (DPDataLoader, len=300)\n",
            "05/13/2022 11:13:16:INFO:(rank 0) Average batch size per GPU: 100\n",
            "Epoch: 0 \tTrain Loss: 1.5412 | Train Accuracy: 0.57 | Test Accuracy: 0.73 |(ε = 0.87)\n",
            "Epoch: 1 \tTrain Loss: 0.6717 | Train Accuracy: 0.79 | Test Accuracy: 0.83 |(ε = 0.91)\n",
            "Epoch: 2 \tTrain Loss: 0.5659 | Train Accuracy: 0.85 | Test Accuracy: 0.86 |(ε = 0.96)\n",
            "Epoch: 3 \tTrain Loss: 0.5347 | Train Accuracy: 0.87 | Test Accuracy: 0.88 |(ε = 1.00)\n",
            "Epoch: 4 \tTrain Loss: 0.5178 | Train Accuracy: 0.88 | Test Accuracy: 0.90 |(ε = 1.03)\n",
            "Epoch: 5 \tTrain Loss: 0.4750 | Train Accuracy: 0.90 | Test Accuracy: 0.91 |(ε = 1.07)\n",
            "Epoch: 6 \tTrain Loss: 0.4502 | Train Accuracy: 0.90 | Test Accuracy: 0.91 |(ε = 1.11)\n",
            "Epoch: 7 \tTrain Loss: 0.4358 | Train Accuracy: 0.91 | Test Accuracy: 0.92 |(ε = 1.14)\n",
            "Epoch: 8 \tTrain Loss: 0.4186 | Train Accuracy: 0.92 | Test Accuracy: 0.92 |(ε = 1.18)\n",
            "Epoch: 9 \tTrain Loss: 0.4129 | Train Accuracy: 0.92 | Test Accuracy: 0.93 |(ε = 1.21)\n"
          ]
        }
      ],
      "source": [
        "!python -W ignore opacus_ddp_demo.py"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a2844d05",
      "metadata": {},
      "source": [
        "## Chapter 2: Data and distributed sampler"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c86099a0",
      "metadata": {},
      "source": [
        "**Note**: The following two chapters discuss the advanced usage of Opacus and its implementation details. We strongly recommend reading the tutorial on [Advanced Features of Opacus](https://opacus.ai/tutorials/intro_to_advanced_features) before proceeding.\n",
        "\n",
        "Now let's look inside the `make_private` method and see what it does to enable DDP processing. We'll start with the modifications made to the `DataLoader`.\n",
        "\n",
        "As a reminder, `DPDataLoader` is different from a regular `DataLoader` in only one aspect - it samples data with a uniform-with-replacement random sampler (a.k.a. \"Poisson sampling\"). This means that instead of a fixed batch size we have a sampling rate: the probability with which each sample is included in the next batch."
      ]
    },
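    {
      "cell_type": "markdown",
      "id": "5be29f13",
      "metadata": {},
      "source": [
        "To see what a sampling rate means in practice, here is a minimal pure-Python sketch of Poisson sampling. It is not how Opacus implements its sampler; `poisson_sample` is a hypothetical helper, and the sample rate of `200 / 60000` is chosen to mirror the Chapter 1 setup.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def poisson_sample(num_examples, sample_rate, rng):\n",
        "    # Include every index independently with probability sample_rate,\n",
        "    # so the batch size is random rather than fixed.\n",
        "    return [i for i in range(num_examples) if rng.random() < sample_rate]\n",
        "\n",
        "rng = random.Random(0)  # fixed seed, only to make the sketch repeatable\n",
        "num_examples, sample_rate = 60000, 200 / 60000\n",
        "batch_sizes = [len(poisson_sample(num_examples, sample_rate, rng)) for _ in range(20)]\n",
        "\n",
        "print(min(batch_sizes), max(batch_sizes))    # sizes fluctuate from batch to batch\n",
        "print(sum(batch_sizes) / len(batch_sizes))   # close to sample_rate * num_examples\n",
        "```\n",
        "\n",
        "Individual batches differ in size, but their average stays close to `sample_rate * num_examples`, which is why Opacus reports an average (expected) batch size rather than a fixed one."
      ]
    },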
    {
      "cell_type": "markdown",
      "id": "5ef85e84",
      "metadata": {},
      "source": [
        "Let's now initialize a regular data loader and then transform it into a `DPDataLoader`. This is exactly how we do it in the `make_private()` method.\n",
        "\n",
        "Below we'll initialize three data loaders:\n",
        "- Non-distributed\n",
        "- Distributed, non-private\n",
        "- Distributed, private (with Poisson sampling)\n",
        "\n",
        "All three are initialized so that the logical batch size is 64.\n",
        "\n",
        "Note that the `make_private` method uses `DPDataLoader` only when `poisson_sampling` is set to `True` (which is the default value). If `poisson_sampling` is set to `False` when using `DPDDP`, we need to explicitly pass a `DistributedSampler` and set `batch_size` to `BATCH_SIZE // world_size` in the `DataLoader`, as in the non-private case."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "id": "ea3bc937",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Writing opacus_distributed_data_loader_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile opacus_distributed_data_loader_demo.py\n",
        "\n",
        "from opacus_ddp_demo import setup, cleanup, mnist_train_ds\n",
        "import logging\n",
        "from torch.utils.data import DataLoader, DistributedSampler\n",
        "from opacus.data_loader import DPDataLoader\n",
        "\n",
        "logger = logging.getLogger(__name__)\n",
        "logger.setLevel(logging.INFO)\n",
        "\n",
        "BATCH_SIZE = 64\n",
        "\n",
        "def init_data(rank, world_size):\n",
        "    setup(rank, world_size)\n",
        "\n",
        "    non_distributed_dl = DataLoader(\n",
        "        mnist_train_ds,\n",
        "        batch_size=BATCH_SIZE\n",
        "    )\n",
        "\n",
        "    distributed_non_private_dl = DataLoader(\n",
        "        mnist_train_ds,\n",
        "        batch_size=BATCH_SIZE // world_size,\n",
        "        sampler=DistributedSampler(mnist_train_ds),\n",
        "    )\n",
        "\n",
        "    private_dl = DPDataLoader.from_data_loader(non_distributed_dl, distributed=True)\n",
        "\n",
        "    if rank == 0:\n",
        "        logger.info(\n",
        "            f\"(rank {rank}) Non-distributed non-private data loader. \"\n",
        "            f\"#batches: {len(non_distributed_dl)}, \"\n",
        "            f\"#data points: {len(non_distributed_dl.sampler)}, \"\n",
        "            f\"batch_size: {non_distributed_dl.batch_size}\"\n",
        "        )\n",
        "\n",
        "        logger.info(\n",
        "            f\"(rank {rank}) Distributed, non-private data loader. \"\n",
        "            f\"#batches: {len(distributed_non_private_dl)}, \"\n",
        "            f\"#data points: {len(distributed_non_private_dl.sampler)}, \"\n",
        "            f\"batch_size: {distributed_non_private_dl.batch_size}\"\n",
        "        )\n",
        "\n",
        "        logger.info(\n",
        "            f\"(rank {rank}) Distributed, private data loader. \"\n",
        "            f\"#batches: {len(private_dl)}, \"\n",
        "            f\"#data points: {private_dl.batch_sampler.num_samples}, \"\n",
        "            f\"sample_rate: {private_dl.sample_rate:.6f}, \"\n",
        "            f\"avg batch_size (=sample_rate*num_data_points): {int(private_dl.sample_rate*private_dl.batch_sampler.num_samples)}\"\n",
        "        )\n",
        "\n",
        "    # Tear down the process group created in setup()\n",
        "    cleanup()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "id": "3dba7efa",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_distributed_data_loader_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_distributed_data_loader_demo.py\n",
        "\n",
        "import torch\n",
        "import torch.multiprocessing as mp\n",
        "\n",
        "world_size = torch.cuda.device_count()\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    mp.spawn(\n",
        "        init_data,\n",
        "        args=(world_size,),\n",
        "        nprocs=world_size,\n",
        "        join=True\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "98adeb98",
      "metadata": {},
      "source": [
        "Let's see what happens when we run it - and what exactly the `from_data_loader` factory did.\n",
        "\n",
        "Notice that our private `DataLoader` was initialized from a non-distributed, non-private data loader, and yet all the basic parameters (per-GPU batch size and number of examples per GPU) match those of the distributed, non-private data loader."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "99ef680c",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "05/13/2022 11:14:53:INFO:(rank 0) Non-distributed non-private data loader. #batches: 938, #data points: 60000, batch_size: 64\n",
            "05/13/2022 11:14:53:INFO:(rank 0) Distributed, non-private data loader. #batches: 938, #data points: 30000, batch_size: 32\n",
            "05/13/2022 11:14:53:INFO:(rank 0) Distributed, private data loader. #batches: 938, #data points: 30000, sample_rate: 0.001066, avg batch_size (=sample_rate*num_data_points): 31\n"
          ]
        }
      ],
      "source": [
        "!python -W ignore opacus_distributed_data_loader_demo.py"
      ]
    },
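    {
      "cell_type": "markdown",
      "id": "7a60c3d8",
      "metadata": {},
      "source": [
        "The numbers in the last log line can be reproduced by hand. The sketch below assumes, based on the logged values, that the sample rate is one over the number of batches in the original loader and that each of the 2 workers sees half of the dataset; the variable names are illustrative.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "dataset_size, batch_size, world_size = 60000, 64, 2\n",
        "\n",
        "num_batches = math.ceil(dataset_size / batch_size)   # 938, as logged\n",
        "sample_rate = 1 / num_batches                         # ~0.001066\n",
        "points_per_worker = dataset_size // world_size        # 30000\n",
        "avg_batch_size = sample_rate * points_per_worker      # ~31.98\n",
        "\n",
        "print(num_batches, f'{sample_rate:.6f}', int(avg_batch_size))  # 938 0.001066 31\n",
        "```\n",
        "\n",
        "The logged average of 31 rather than 32 comes from truncating roughly 31.98 with `int()`, so the expected per-GPU batch size is effectively the same as in the non-private distributed loader."
      ]
    },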
    {
      "cell_type": "markdown",
      "id": "d938f572",
      "metadata": {},
      "source": [
        "## Chapter 3: Synchronisation"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "491a3ed8",
      "metadata": {},
      "source": [
        "One significant difference between `DDP` and `DPDDP` is how they approach synchronisation.\n",
        "\n",
        "Normally, with Distributed Data Parallel, the forward and backward passes are synchronisation points, and the `DDP` wrapper ensures that the gradients are synchronised across workers as soon as they are available for each layer during the backward pass.\n",
        "\n",
        "Opacus, however, needs a later synchronisation point. Before we can use the gradients, we need to clip them and add noise. This is done in the optimizer, which moves the synchronisation point from the backward pass to the optimization step.\n",
        "Additionally, to simplify the calculations, we only add noise on the worker with `rank=0`, and use the noise scale calibrated to the combined batch across all workers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "id": "9821ce76",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Writing opacus_sync_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile opacus_sync_demo.py\n",
        "\n",
        "from opacus_ddp_demo import setup, cleanup, mnist_train_ds, SampleConvNet\n",
        "import logging\n",
        "from torch.utils.data import DataLoader\n",
        "import torch.optim as optim\n",
        "from opacus.data_loader import DPDataLoader\n",
        "from opacus import GradSampleModule\n",
        "from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP\n",
        "from opacus.optimizers import DistributedDPOptimizer\n",
        "from torch.nn.parallel import DistributedDataParallel as DDP\n",
        "\n",
        "logger = logging.getLogger(__name__)\n",
        "logger.setLevel(logging.INFO)\n",
        "\n",
        "BATCH_SIZE = 64\n",
        "LR = 64\n",
        "\n",
        "def init_training(rank, world_size):\n",
        "    model = SampleConvNet()\n",
        "    optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0)\n",
        "\n",
        "    model = GradSampleModule(model)\n",
        "    model = DPDDP(model)\n",
        "\n",
        "    optimizer = DistributedDPOptimizer(\n",
        "        optimizer=optimizer,\n",
        "        noise_multiplier=0.,\n",
        "        max_grad_norm=100.,\n",
        "        expected_batch_size=BATCH_SIZE//world_size,\n",
        "    )\n",
        "\n",
        "    data_loader = DPDataLoader.from_data_loader(\n",
        "        data_loader=DataLoader(\n",
        "            mnist_train_ds,\n",
        "            batch_size=BATCH_SIZE,\n",
        "        ),\n",
        "        distributed=True,\n",
        "    )\n",
        "\n",
        "\n",
        "    return model, optimizer, data_loader"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a1e810cc",
      "metadata": {},
      "source": [
        "Now that we've initialized the `DifferentiallyPrivateDistributedDataParallel` model and the `DistributedDPOptimizer`, let's see how they work together.\n",
        "\n",
        "`DifferentiallyPrivateDistributedDataParallel` is a no-op: we only perform model synchronisation on initialization and do nothing on forward and backward passes.\n",
        "\n",
        "`DistributedDPOptimizer`, on the other hand, does all the heavy lifting:\n",
        "- It does gradient clipping on each worker independently\n",
        "- It adds noise on the worker with `rank=0` only\n",
        "- It calls `torch.distributed.all_reduce` on the gradients in `step()`, right before applying them"
      ]
    },
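    {
      "cell_type": "markdown",
      "id": "e93b1f52",
      "metadata": {},
      "source": [
        "To make the sequence concrete, here is a simplified single-process simulation of these three steps. It is a sketch under stated assumptions, not Opacus's actual implementation: workers are plain Python lists, `all_reduce_sum` is a hypothetical stand-in for `torch.distributed.all_reduce`, and the final division by the combined expected batch size is an assumption about how the average is formed.\n",
        "\n",
        "```python\n",
        "import math\n",
        "import random\n",
        "\n",
        "def clip(grad, max_norm):\n",
        "    # Scale a per-sample gradient so its L2 norm is at most max_norm.\n",
        "    norm = math.sqrt(sum(g * g for g in grad))\n",
        "    factor = min(1.0, max_norm / (norm + 1e-6))\n",
        "    return [g * factor for g in grad]\n",
        "\n",
        "def all_reduce_sum(per_worker):\n",
        "    # Hypothetical stand-in for torch.distributed.all_reduce (sum).\n",
        "    return [sum(vals) for vals in zip(*per_worker)]\n",
        "\n",
        "def simulated_step(per_worker_samples, max_norm, noise_multiplier, expected_batch_size, rng):\n",
        "    summed = []\n",
        "    for rank, samples in enumerate(per_worker_samples):\n",
        "        # 1. Each worker clips its own per-sample gradients and sums them.\n",
        "        clipped = [clip(s, max_norm) for s in samples]\n",
        "        local = [sum(vals) for vals in zip(*clipped)]\n",
        "        # 2. Only rank 0 adds noise, once for the combined batch.\n",
        "        if rank == 0:\n",
        "            std = noise_multiplier * max_norm\n",
        "            local = [g + rng.gauss(0.0, std) for g in local]\n",
        "        summed.append(local)\n",
        "    # 3. Gradients are synchronised only now, inside the step.\n",
        "    total = all_reduce_sum(summed)\n",
        "    world_size = len(per_worker_samples)\n",
        "    return [g / (expected_batch_size * world_size) for g in total]\n",
        "\n",
        "rng = random.Random(0)\n",
        "# Two simulated workers, each holding two 2-dimensional per-sample gradients.\n",
        "workers = [[[3.0, 4.0], [0.3, 0.4]], [[0.0, 1.0], [6.0, 8.0]]]\n",
        "print(simulated_step(workers, max_norm=1.0, noise_multiplier=0.0, expected_batch_size=2, rng=rng))\n",
        "```\n",
        "\n",
        "With `noise_multiplier=0.0` the result is simply the average of the clipped per-sample gradients across both workers, which is the same simplification the demo in this chapter relies on."
      ]
    },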
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "fa600d2f",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_sync_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_sync_demo.py\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import numpy as np\n",
        "\n",
        "def launch(rank, world_size):\n",
        "    setup(rank, world_size)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    model, optimizer, data_loader = init_training(rank, world_size)\n",
        "    model.to(rank)\n",
        "    model.train()\n",
        "\n",
        "    for data, target in data_loader:\n",
        "        # Batches from DPDataLoader are already tensors; just move them to the device\n",
        "        data, target = data.to(rank), target.to(rank)\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        output = model(data)\n",
        "        loss = criterion(output, target)\n",
        "        loss.backward()\n",
        "\n",
        "        flat_grad = torch.cat([p.grad_sample.sum(dim=0).view(-1) for p in model.parameters()]).cpu().numpy() / optimizer.expected_batch_size\n",
        "        logger.info(\n",
        "            f\"(rank={rank}) Gradient norm before optimizer.step(): {np.linalg.norm(flat_grad):.4f}\"\n",
        "        )\n",
        "        logger.info(\n",
        "            f\"(rank={rank}) Gradient sample before optimizer.step(): {flat_grad[:3]}\"\n",
        "        )\n",
        "\n",
        "        optimizer.step()\n",
        "\n",
        "        flat_grad = torch.cat([p.grad.view(-1) for p in model.parameters()]).cpu().numpy()\n",
        "        logger.info(\n",
        "            f\"(rank={rank}) Gradient norm after optimizer.step(): {np.linalg.norm(flat_grad):.4f}\"\n",
        "        )\n",
        "        logger.info(\n",
        "            f\"(rank={rank}) Gradient sample after optimizer.step(): {flat_grad[:3]}\"\n",
        "        )\n",
        "\n",
        "        break\n",
        "\n",
        "    cleanup()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "d249d0d5",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Appending to opacus_sync_demo.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile -a opacus_sync_demo.py\n",
        "\n",
        "import torch.multiprocessing as mp\n",
        "import torch\n",
        "\n",
        "world_size = torch.cuda.device_count()\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    mp.spawn(\n",
        "        launch,\n",
        "        args=(world_size,),\n",
        "        nprocs=world_size,\n",
        "        join=True\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a9c95972",
      "metadata": {},
      "source": [
        "When we run the code, notice that the gradients are not synchronised after `loss.backward()`, but only after `optimizer.step()`. For this example, we've set the privacy parameters to effectively disable noise and clipping, so the synchronised gradient is indeed the average of the individual workers' gradients."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "c4a8967c",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "05/13/2022 11:15:22:INFO:(rank=1) Gradient norm before optimizer.step(): 0.9924\n",
            "05/13/2022 11:15:22:INFO:(rank=1) Gradient sample before optimizer.step(): [-0.00525815 -0.01079952 -0.01051272]\n",
            "05/13/2022 11:15:22:INFO:(rank=0) Gradient norm before optimizer.step(): 1.7812\n",
            "05/13/2022 11:15:22:INFO:(rank=0) Gradient sample before optimizer.step(): [-0.0181896  -0.02559735 -0.02745825]\n",
            "05/13/2022 11:15:22:INFO:(rank=0) Gradient norm after optimizer.step(): 1.2387\n",
            "05/13/2022 11:15:22:INFO:(rank=1) Gradient norm after optimizer.step(): 1.2387\n",
            "05/13/2022 11:15:22:INFO:(rank=0) Gradient sample after optimizer.step(): [-0.01172432 -0.01819846 -0.01898623]\n",
            "05/13/2022 11:15:22:INFO:(rank=1) Gradient sample after optimizer.step(): [-0.01172432 -0.01819846 -0.01898623]\n"
          ]
        }
      ],
      "source": [
        "!python -W ignore opacus_sync_demo.py"
      ]
    },
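    {
      "cell_type": "markdown",
      "id": "2f8d6a4b",
      "metadata": {},
      "source": [
        "We can check the averaging claim directly against the logged values. The snippet below copies the three gradient entries printed for each rank and compares their mean with the entries printed after `optimizer.step()`; the tolerance of `1e-5` is an arbitrary choice for this illustration.\n",
        "\n",
        "```python\n",
        "# Gradient samples copied from the log output above.\n",
        "rank0_before = [-0.0181896, -0.02559735, -0.02745825]\n",
        "rank1_before = [-0.00525815, -0.01079952, -0.01051272]\n",
        "after_step = [-0.01172432, -0.01819846, -0.01898623]\n",
        "\n",
        "averaged = [(a + b) / 2 for a, b in zip(rank0_before, rank1_before)]\n",
        "max_diff = max(abs(x - y) for x, y in zip(averaged, after_step))\n",
        "\n",
        "print(averaged)\n",
        "print(max_diff < 1e-5)  # True: the synchronised gradient matches the average\n",
        "```\n",
        "\n",
        "Note that the gradient norm after the step (1.2387) is not the mean of the two norms before it (0.9924 and 1.7812): the norm of an averaged vector is at most the average of the norms."
      ]
    },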
    {
      "cell_type": "markdown",
      "id": "78da9b43-d92a-436f-b000-4b633e85479a",
      "metadata": {},
      "source": [
        "## Cleanup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "1cf71582-6808-4fb8-b7b0-9afb462b75b8",
      "metadata": {},
      "outputs": [],
      "source": [
        "%%bash\n",
        "rm opacus_ddp_demo.py\n",
        "rm opacus_distributed_data_loader_demo.py\n",
        "rm opacus_sync_demo.py"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
